{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ErOf16E05Dn7"
      },
      "source": [
        "# High-Fidelity Generative Image Compression\n",
        "\n",
        "This colab can be used to compress images using HiFiC. This can also be achieved\n",
        "by running `tfci.py`, as [explained in the README](https://github.com/tensorflow/compression/tree/master/models/hific#running-models-trained-by-us-locally).\n",
        "\n",
        "Please visit [hific.github.io](https://hific.github.io) for more information."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Umer7W0VbITT"
      },
      "source": [
        "## Setup Colab"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LO_MNEQ7Bhbw"
      },
      "outputs": [],
      "source": [
        "# Installs the latest version of TFC compatible with the installed TF version.\n",
        "!pip install tensorflow-compression~=$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\\d+\\.\\d+).*/\\1.0/sg')\n",
        "\n",
        "# Downloads the 'models' directory from Github.\n",
        "![[ -e /tfc ]] || git clone https://github.com/tensorflow/compression /tfc\n",
        "%cd /tfc/models\n",
        "\n",
        "# Checks if tfci.py is available.\n",
        "import tfci\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "op4hPwy_mkPm"
      },
      "source": [
        "## Enabling GPU\n",
        "\n",
        "GPU should be enabled for this colab. If the next cell prints a warning, do the following:\n",
        "- Navigate to Edit→Notebook Settings\n",
        "- select GPU from the Hardware Accelerator drop-down\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x-yLUG_tmo3M"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "if not tf.config.list_physical_devices('GPU'):\n",
        "  print('WARNING: No GPU found. Might be slow!')\n",
        "else:\n",
        "  print('Found GPU.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rQ9-8ZsTf7Hj"
      },
      "source": [
        "## Imports and Definitions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vtd1l70Pf95V"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import zipfile\n",
        "from google.colab import files\n",
        "import collections\n",
        "from PIL import Image\n",
        "from IPython.display import Image as DisplayImage\n",
        "from IPython.display import Javascript\n",
        "from IPython.core.display import display, HTML\n",
        "import tfci\n",
        "import urllib.request\n",
        "\n",
        "tf.get_logger().setLevel('WARN')  # Only show Warnings\n",
        "\n",
        "FILES_DIR = '/content/files'\n",
        "OUT_DIR = '/content/out'\n",
        "DEFAULT_IMAGE_URL = ('https://storage.googleapis.com/hific/clic2020/'\n",
        "                     'images/originals/ad249bba099568403dc6b97bc37f8d74.png')\n",
        "\n",
        "os.makedirs(FILES_DIR, exist_ok=True)\n",
        "os.makedirs(OUT_DIR, exist_ok=True)\n",
        "\n",
        "File = collections.namedtuple('File', ['full_path', 'num_bytes', 'bpp'])\n",
        "\n",
        "def print_html(html):\n",
        "  display(HTML(html + '<br/>'))\n",
        "\n",
        "def make_cell_large():\n",
        "  display(Javascript(\n",
        "      '''google.colab.output.setIframeHeight(0, true, {maxHeight: 5000})'''))\n",
        "\n",
        "def get_default_image(output_dir):\n",
        "  output_path = os.path.join(output_dir, os.path.basename(DEFAULT_IMAGE_URL))\n",
        "  print('Downloading', DEFAULT_IMAGE_URL, '\\n->', output_path)\n",
        "  urllib.request.urlretrieve(DEFAULT_IMAGE_URL, output_path)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Ngs9WvmbTMH"
      },
      "source": [
        "## Load files"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NgtIlL2ADCI2"
      },
      "outputs": [],
      "source": [
        "#@title Setup { vertical-output: false, run: \"auto\", display-mode: \"form\" }\n",
        "#@markdown #### Custom Images\n",
        "#@markdown Tick the following if you want to upload your own images to compress.\n",
        "#@markdown Otherwise, a default image will be used.\n",
        "#@markdown \n",
        "#@markdown **Note**: We support JPG and PNG (without alpha channels).\n",
        "#@markdown \n",
        "\n",
        "upload_custom_images = False #@param {type:\"boolean\", label:\"HI\"}\n",
        "\n",
        "if upload_custom_images:\n",
        "  uploaded = files.upload()\n",
        "  for name, content in uploaded.items():\n",
        "    with open(os.path.join(FILES_DIR, name), 'wb') as fout:\n",
        "      print('Writing', name, '...')\n",
        "      fout.write(content)\n",
        "\n",
        "#@markdown #### Select a model\n",
        "#@markdown Different models target different bitrates.\n",
        "\n",
        "model = 'hific-lo' #@param [\"hific-lo\", \"hific-mi\", \"hific-hi\"]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GYcbc2HupTRD"
      },
      "outputs": [],
      "source": [
        "if 'upload_custom_images' not in locals():\n",
        "  print('ERROR: Please run the previous cell!')\n",
        "  # Setting defaults anyway.\n",
        "  upload_custom_images = False\n",
        "  model = 'hific-lo'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e0C4vMqZsnqA"
      },
      "outputs": [],
      "source": [
        "all_files = os.listdir(FILES_DIR)\n",
        "if not upload_custom_images or not all_files:\n",
        "  print('Downloading default image...')\n",
        "  get_default_image(FILES_DIR)\n",
        "  print()\n",
        "\n",
        "all_files = os.listdir(FILES_DIR)\n",
        "print(f'Got the following files ({len(all_files)}):')\n",
        "\n",
        "for file_name in all_files:\n",
        "  img = Image.open(os.path.join(FILES_DIR, file_name))\n",
        "  w, h = img.size\n",
        "  img = img.resize((w // 15, h // 15))\n",
        "  print('- ' + file_name + ':')\n",
        "  display(img)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "guX3Q_AsTE7-"
      },
      "source": [
        "# Compress images"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kd02HOhLBj6e"
      },
      "outputs": [],
      "source": [
        "SUPPORTED_EXT = {'.png', '.jpg'}\n",
        "\n",
        "all_files = os.listdir(FILES_DIR)\n",
        "if not all_files:\n",
        "  raise ValueError(\"Please upload images!\")\n",
        "\n",
        "def get_bpp(image_dimensions, num_bytes):\n",
        "  w, h = image_dimensions\n",
        "  return num_bytes * 8 / (w * h)\n",
        "\n",
        "def has_alpha(img_p):\n",
        "  im = Image.open(img_p)\n",
        "  return im.mode == 'RGBA'\n",
        "\n",
        "all_outputs = []\n",
        "for file_name in all_files:\n",
        "  if os.path.isdir(file_name):\n",
        "    continue\n",
        "  if not any(file_name.endswith(ext) for ext in SUPPORTED_EXT):\n",
        "    print('Skipping', file_name, '...')\n",
        "    continue\n",
        "  full_path = os.path.join(FILES_DIR, file_name)\n",
        "  if has_alpha(full_path):\n",
        "    print('Skipping because of alpha channel:', file_name)\n",
        "    continue\n",
        "  file_name, _ = os.path.splitext(file_name)\n",
        "\n",
        "  compressed_path = os.path.join(OUT_DIR, f'{file_name}_{model}.tfci')\n",
        "  output_path = os.path.join(OUT_DIR, f'{file_name}_{model}.png')\n",
        "  \n",
        "  if os.path.isfile(output_path):\n",
        "    print('Exists already:', output_path)\n",
        "    num_bytes = os.path.getsize(compressed_path)\n",
        "    all_outputs.append(\n",
        "      File(output_path, num_bytes,\n",
        "           get_bpp(Image.open(full_path).size, num_bytes)))\n",
        "    continue\n",
        "\n",
        "  print('Compressing', file_name, 'with', model, '...')\n",
        "  tfci.compress(model, full_path, compressed_path)\n",
        "  num_bytes = os.path.getsize(compressed_path)\n",
        "  print(f'Compressed to {num_bytes} bytes.')\n",
        "\n",
        "  print('Decompressing...')\n",
        "  tfci.decompress(compressed_path, output_path)\n",
        "  \n",
        "  all_outputs.append(\n",
        "      File(output_path, num_bytes,\n",
        "           get_bpp(Image.open(full_path).size, num_bytes)))\n",
        "\n",
        "print('All done!')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HQhQQs-CTkgy"
      },
      "source": [
        "# Show output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3nVCPeDnskD8"
      },
      "outputs": [],
      "source": [
        "make_cell_large()  # Larger output window.\n",
        "\n",
        "for file in all_outputs:\n",
        "  print_html('<hr/>')\n",
        "  print(f'Showing {file.full_path} | {file.num_bytes//1000}kB | {file.bpp:.4f}bpp')\n",
        "  display(Image.open(file.full_path))\n",
        "  print_html('<hr/>')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4b-wkBnyrTAR"
      },
      "source": [
        "### Download all compressed images.\n",
        "\n",
        "To download all images, run the following cell.\n",
        "\n",
        "You can also use the _Files_ tab on the left to manually select images.\n",
        "\n",
        "---\n",
        "\n",
        "#### **Note**: the images are saved as PNGs and thus very large. The bitrate used by HiFiC is given in the name."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9BKccvcTpj1k"
      },
      "outputs": [],
      "source": [
        "ZIP = '/content/images.zip'\n",
        "\n",
        "with zipfile.ZipFile(ZIP, 'w') as zf:\n",
        "  for f in all_outputs:\n",
        "    path_with_bpp = f.full_path.replace('.png', f'-{f.bpp:.3f}bpp.png')\n",
        "    zf.write(f.full_path, os.path.basename(path_with_bpp))\n",
        "\n",
        "files.download(ZIP) "
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "colab.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
